{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RHt3rGLLxfWs"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torchvision"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-gBuXTWqxuMV",
        "outputId": "42ca4293-2496-4a03-ced5-a02d9253c721"
      },
      "outputs": [],
      "source": [
        "# Report the installed torch / torchvision versions.\n",
        "(torch.__version__, torchvision.__version__)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gjdq8wLcQWd7",
        "outputId": "8c5ed594-021c-4c67-e0df-949151d1247f"
      },
      "outputs": [],
      "source": [
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install -q git+https://github.com/google-deepmind/tapnet.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mrQ_uQDeQee-",
        "outputId": "ce0d3868-7b82-4e8e-88ae-fd5ae960f95b"
      },
      "outputs": [],
      "source": [
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install -q git+https://github.com/google-deepmind/recurrentgemma.git@main"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NIZQxfcgyFM9",
        "outputId": "8cd10428-69a7-4c39-f3f9-eec9c07ed108"
      },
      "outputs": [],
      "source": [
        "# Keep numpy below 2.1.0; %pip (rather than !pip) installs into the running kernel's environment.\n",
        "%pip install \"numpy\u003c2.1.0\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gju7QkZH2XLL"
      },
      "outputs": [],
      "source": [
        "import tqdm"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uQ6Ako7IQERX"
      },
      "source": [
        "### TAPNext"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hy98wCMgQx6x"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "import numpy as np\n",
        "from tapnet.tapnext.tapnext_torch import TAPNext\n",
        "import torch.nn.functional as F"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gJpHisHSQERc"
      },
      "source": [
        "### Create the model (randomly initialized; no checkpoint is loaded)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pnTma4NKQERc"
      },
      "outputs": [],
      "source": [
        "# Build TAPNext for 256x256 input frames.\n",
        "model = TAPNext(image_size=(256, 256))\n",
        "# Move the model to the GPU; as the cell's last expression, the result is displayed.\n",
        "model.cuda()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CukzSyYSQERd"
      },
      "source": [
        "### Run inference"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 477
        },
        "id": "YNf98GICtJMa",
        "outputId": "05da0eef-b914-4d1b-f740-5152f47ddb4c"
      },
      "outputs": [],
      "source": [
        "# Inference only: switch to eval mode and turn off gradient tracking for every parameter.\n",
        "model.eval()\n",
        "for param in model.parameters():\n",
        "  param.requires_grad = False"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IUowukX9Qx6x"
      },
      "outputs": [],
      "source": [
        "# Number of query points passed to the model (second dimension of query_points).\n",
        "NUM_QUERIES = 1024"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TyiVKz9MQx6x"
      },
      "outputs": [],
      "source": [
        "# All-zero dummy inputs for the speed benchmark, allocated directly on the GPU.\n",
        "# (Creating them on the host and calling .cuda() would first build an ~800 MB CPU\n",
        "# tensor and then copy it to the device.)\n",
        "# video has shape (1, 1024, 256, 256, 3); the benchmark loop slices it along dim 1,\n",
        "# one frame per step.\n",
        "video = torch.zeros((1, 1024, 256, 256, 3), dtype=torch.float32, device='cuda')\n",
        "# query_points has shape (1, NUM_QUERIES, 3).\n",
        "# NOTE(review): the 3 components are presumably (t, y, x) -- confirm against TAPNext.\n",
        "query_points = torch.zeros((1, NUM_QUERIES, 3), dtype=torch.float32, device='cuda')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vLRpvleUQx6x"
      },
      "outputs": [],
      "source": [
        "# Autocast precision used by the benchmark: pick torch.float16 or torch.bfloat16.\n",
        "DTYPE = torch.float16"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n6x1NxPaQx6x"
      },
      "outputs": [],
      "source": [
        "# Benchmark of per-frame tracking speed.\n",
        "# Frames before WARMUP_FRAMES are run but not timed (torch.compile compilation and\n",
        "# GPU burn-in). The value is clamped so that a short video still has a timed section.\n",
        "WARMUP_FRAMES = max(1, min(512, video.shape[1] // 2))\n",
        "\n",
        "fwd = torch.compile(model.forward)\n",
        "with torch.no_grad(), torch.amp.autocast('cuda', dtype=DTYPE, enabled=True):\n",
        "  # The first frame, together with the query points, creates the tracking state.\n",
        "  _, _, _, tracking_state = fwd(video=video[:, :1], query_points=query_points)\n",
        "  timed_frames = 0\n",
        "  start = None\n",
        "  for k in tqdm.tqdm(range(1, video.shape[1])):\n",
        "    if k == WARMUP_FRAMES:\n",
        "      # CUDA work is asynchronous: wait for the warm-up frames to finish\n",
        "      # before starting the clock.\n",
        "      torch.cuda.synchronize()\n",
        "      start = time.perf_counter()\n",
        "    # Later frames are fed one at a time, carrying the state forward.\n",
        "    _, _, _, tracking_state = fwd(\n",
        "        video=video[:, k : k + 1], state=tracking_state\n",
        "    )\n",
        "    if start is not None:\n",
        "      timed_frames += 1\n",
        "  if timed_frames == 0:\n",
        "    print('Not enough frames to measure speed; need at least 2.')\n",
        "  else:\n",
        "    # Wait for all queued GPU work so the elapsed time covers actual execution,\n",
        "    # not just kernel launches.\n",
        "    torch.cuda.synchronize()\n",
        "    elapsed = time.perf_counter() - start\n",
        "    print('FPS:', timed_frames / elapsed, 'latency', 1000 * elapsed / timed_frames, 'ms')"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.11"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}
